// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#define BCAST_LLKOP EltwiseBinaryType::ELWMUL
#define BCAST_DIM BroadcastType::COL

#include "compute_kernel_api/bcast.h"
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/layernorm.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "compute_kernel_api/eltwise_binary_sfpu.h"
#include "compute_kernel_api/eltwise_unary/rsqrt.h"
#include "compute_kernel_api/transpose_wh_dest.h"
#include "compute_kernel_api/eltwise_unary/binop_with_scalar.h"
#include "compute_kernel_api/welford.h"
#include "compute_kernel_api/transpose_wh.h"

// Shorthand used throughout this kernel for acquire_dst().
ALWI void ACQ() {
    acquire_dst();
}
// Shorthand used throughout this kernel for release_dst().
ALWI void REL() {
    release_dst();
}

namespace NAMESPACE {

// Layernorm compute kernel (Welford variant).
//
// For each of the NCHt outer iterations the kernel:
//   1. (fuse_pre_add only) computes x = in + b tile by tile into cb_x;
//   2. runs welford_tile over the Wt transposed tiles of cb_x to obtain E[x] (dst1)
//      and Var[x] (dst2), then packs, transposes and re-packs them into cb_ex / cb_ex2;
//   3. computes x - E[x] with a column broadcast into cb_xmm;
//   4. computes rsqrt(Var[x] + eps) into cb_ex2pe;
//   5. multiplies cb_xmm by cb_ex2pe (column broadcast) and, when enabled, applies
//      gamma (row-broadcast multiply) and beta (row-broadcast add) before the result
//      reaches cb_out.
//
// The blocked loops advance by `blk` and always process `blk` tiles with no remainder
// handling, so Wt is presumably expected to be a multiple of blk — TODO confirm with the host code.
void MAIN {
    // Runtime arg 0: number of outer-loop iterations; each one consumes Wt tiles of input.
    uint32_t NCHt = get_arg_val<uint32_t>(0);
    // Compile-time args. Index 7 is not read by this kernel.
    constexpr uint32_t Wt = get_compile_time_arg_val(0);        // tiles processed per outer iteration
    constexpr uint32_t blk = get_compile_time_arg_val(1);       // tiles handled per blocked step
    constexpr uint32_t do_gamma = get_compile_time_arg_val(2);  // non-zero enables the gamma multiply stage
    constexpr uint32_t do_beta = get_compile_time_arg_val(3);   // non-zero enables the beta add stage
    constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(4) == 1;  // gates extra reconfig_data_format calls
    constexpr uint32_t W = get_compile_time_arg_val(5);  // forwarded to welford_tile (presumably row width in elements — confirm)
    constexpr uint32_t tile_width = get_compile_time_arg_val(6);  // start_N is advanced by this amount per tile
    constexpr bool fuse_pre_add = static_cast<bool>(get_compile_time_arg_val(8));

    constexpr uint32_t onetile = 1;

    constexpr auto cb_scaler = tt::CBIndex::c_2;  // single tile generated by the reader
    constexpr auto cb_eps = tt::CBIndex::c_3;     // single tile generated by the reader
    constexpr auto cb_in = tt::CBIndex::c_0;      // input x or a for fused pre-add (x=a+b)
    constexpr auto cb_inb = tt::CBIndex::c_1;     // input b for fused pre-add
    constexpr auto cb_out = tt::CBIndex::c_16;    // output
    constexpr auto cb_gamma = tt::CBIndex::c_5;
    constexpr auto cb_beta = tt::CBIndex::c_6;
    constexpr auto cb_xmm = tt::CBIndex::c_24;  // x - E[x]

    constexpr auto cb_ex = tt::CBIndex::c_18;      // E[x]
    constexpr auto cb_ex2 = tt::CBIndex::c_19;     // E[(x-E[x])^2]
    constexpr auto cb_ex2pe = tt::CBIndex::c_21;   // receives rsqrt(E[(x-E[x])^2]+eps)
    constexpr auto cb_fusion = tt::CBIndex::c_22;  // stream gamma/beta
    // Destination of the normalization multiply: the intermediate buffer when a gamma/beta
    // stage follows, otherwise the output buffer directly.
    constexpr auto cb_im_or_out = (do_gamma | do_beta) ? cb_fusion : cb_out;

    // Not referenced anywhere else in this kernel.
    constexpr auto scaler0 = 0;

    //  Either in or in + b if doing fused pre-add
    constexpr auto cb_x = []() -> auto {
        if constexpr (fuse_pre_add) {
            return tt::CBIndex::c_23;
        } else {
            return cb_in;
        }
    }();

    constexpr int dst0 = 0;  // Input tile to Welford's algorithm
    constexpr int dst1 = 1;  // Partial E[x] result for Welford's
    constexpr int dst2 = 2;  // Partial Var[x] result for Welford's

    // Both single-tile buffers are waited on once here; cb_eps is never popped by this kernel.
    cb_wait_front(cb_scaler, 1);  // comes from the reader
    cb_wait_front(cb_eps, 1);     // comes from the reader

    // One-time init, configured for whichever operation runs first in the loop below.
    if constexpr (fuse_pre_add) {
        binary_op_init_common(cb_in, cb_inb, cb_x);
        pack_reconfig_data_format(cb_x);
    } else {
        binary_op_init_common(cb_in, cb_scaler, cb_ex);
        pack_reconfig_data_format(cb_ex);
    }

    for (uint32_t ncht = 0; ncht < NCHt; ncht++) {
        if constexpr (fuse_pre_add) {
            // x = in + b
            add_tiles_init(cb_in, cb_inb);
            reconfig_data_format(cb_in, cb_inb);
            pack_reconfig_data_format(cb_x);
            for (uint32_t wt = 0; wt < Wt; wt += blk) {
                ACQ();
                cb_wait_front(cb_in, blk);
                cb_wait_front(cb_inb, blk);
                cb_reserve_back(cb_x, blk);
                for (uint32_t j = 0; j < blk; j++) {
                    add_tiles(cb_in, cb_inb, j, j, j);
                    pack_tile(j, cb_x);
                }
                REL();
                cb_push_back(cb_x, blk);  // publish x = in + b for the stages below
                cb_pop_front(cb_in, blk);
                cb_pop_front(cb_inb, blk);
            }
            reconfig_data_format(cb_in, cb_x, cb_inb, cb_scaler);
        }

        // Simultaneous calculation of E[x] and Var[x] using Welford's algorithm
        ACQ();

        // Running offset passed to welford_tile; grows by tile_width for every tile consumed.
        uint32_t start_N = 0;
        transpose_wh_init_short(cb_x);
        welford_init();
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            // Cumulative wait: tiles stay in cb_x (nothing is popped here) and are indexed by wt + j.
            cb_wait_front(cb_x, wt + blk);
            for (uint32_t j = 0; j < blk; j++) {
                // Welford's needs transposed input tile
                transpose_wh_tile(cb_x, wt + j, dst0);
                welford_tile<dst0, dst1, dst2, true, 0>(start_N, W, 0, {});
                start_N += tile_width;
            }
        }

        // Transpose dst1 and dst2 back to columns
        // Done as a round trip: pack dst1/dst2 to cb_ex/cb_ex2, transpose them back into
        // dst1/dst2, then pack the transposed tiles into the same buffers again.
        cb_reserve_back(cb_ex, onetile);
        cb_reserve_back(cb_ex2, onetile);

        // NOTE(review): pack_tile is issued before tile_regs_commit() and without a
        // tile_regs_wait(); presumably the ACQ() above covers the packer-side wait — confirm.
        pack_reconfig_data_format(cb_ex);
        pack_tile(dst1, cb_ex);
        pack_reconfig_data_format(cb_ex2);
        pack_tile(dst2, cb_ex2);

        tile_regs_commit();
        tile_regs_release();

        cb_push_back(cb_ex, onetile);
        cb_push_back(cb_ex2, onetile);

        cb_wait_front(cb_ex, onetile);
        cb_wait_front(cb_ex2, onetile);
        reconfig_data_format_srca(cb_ex);
        transpose_wh_init_short(cb_ex);

        tile_regs_acquire();
        transpose_wh_tile(cb_ex, 0, dst1);
        transpose_wh_tile(cb_ex2, 0, dst2);
        tile_regs_commit();

        // The untransposed tiles are no longer needed once they are in dst1/dst2.
        cb_pop_front(cb_ex, onetile);
        cb_pop_front(cb_ex2, onetile);

        cb_reserve_back(cb_ex, onetile);
        cb_reserve_back(cb_ex2, onetile);
        pack_reconfig_data_format(cb_ex);

        tile_regs_wait();
        pack_tile(dst1, cb_ex);
        pack_reconfig_data_format(cb_ex2);
        pack_tile(dst2, cb_ex2);
        tile_regs_release();

        cb_push_back(cb_ex, onetile);
        cb_push_back(cb_ex2, onetile);
        // NOTE(review): this REL() is the textual partner of the ACQ() before the Welford loop,
        // but a commit/release pair and a full acquire/commit/wait/release cycle have already
        // run in between — confirm the DST acquire/release accounting is intended.
        REL();

        // x - E[x]
        // Reuse cb_x since we didn't pop anything from it
        if constexpr (FLOAT32_DTYPE) {
            reconfig_data_format(cb_x, cb_ex);
        }
        cb_wait_front(cb_ex, onetile);  // should have 1 tile
        // All Wt output tiles are reserved up front; they are pushed blk at a time below.
        cb_reserve_back(cb_xmm, Wt);
        sub_bcast_cols_init_short(cb_x, cb_ex);
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            ACQ();
            for (uint32_t wtr = 0; wtr < blk; wtr++) {
                sub_tiles_bcast_cols(cb_x, cb_ex, wt + wtr, 0, wtr);
                pack_tile(wtr, cb_xmm);
            }
            cb_push_back(cb_xmm, blk);
            REL();
        }
        // E[x] and the Wt input tiles are finished with for this iteration.
        cb_pop_front(cb_ex, 1);
        cb_pop_front(cb_x, Wt);
        cb_wait_front(cb_xmm, Wt);

        if constexpr (!fuse_pre_add) {
            reconfig_data_format_srca(cb_x, cb_xmm);
        }

        // Var(x) + eps
        // followed by rsqrt, so cb_ex2pe ends up holding rsqrt(Var(x) + eps)
        binary_op_init_common(cb_ex2, cb_eps, cb_ex2pe);
        if constexpr (FLOAT32_DTYPE) {
            reconfig_data_format(cb_ex2, cb_eps);
        }
        cb_wait_front(cb_ex2, onetile);  // should have 1 tile
        ACQ();
        add_tiles_init(cb_ex2, cb_eps);
        add_tiles(cb_ex2, cb_eps, 0, 0, dst0);

        cb_reserve_back(cb_ex2pe, onetile);
        rsqrt_tile_init();
        rsqrt_tile(dst0);
        pack_tile(dst0, cb_ex2pe);
        cb_push_back(cb_ex2pe, onetile);
        REL();
        cb_pop_front(cb_ex2, onetile);

        // Remainder of the layernorm operation
        // norm(x) * gamma + beta,
        // where norm(x) is:
        // (x - E[x]) / sqrt(E[(x-E[x])^2] + eps)
        cb_wait_front(cb_ex2pe, onetile);
        for (uint32_t wt = 0; wt < Wt; wt += blk) {
            // Formats are re-established every block because the gamma/beta stages below change them.
            reconfig_data_format(cb_xmm, cb_ex2pe);
            if constexpr (do_gamma == 0 && do_beta == 0) {
                pack_reconfig_data_format(cb_out);
            } else {
                pack_reconfig_data_format(cb_fusion);
            }
            cb_reserve_back(cb_im_or_out, blk);

            ACQ();
            mul_bcast_cols_init_short(cb_xmm, cb_ex2pe);
            for (uint32_t wtr = 0; wtr < blk; wtr++) {
                // cb_xmm[wt+wtr] since we pop Wt from cb_xmm after the entire loop
                mul_tiles_bcast_cols(cb_xmm, cb_ex2pe, wt + wtr, 0, wtr);
                pack_tile(wtr, cb_im_or_out);  // cb_fusion when gamma/beta follow, otherwise cb_out
            }
            cb_push_back(cb_im_or_out, blk);  // if no gamma/beta are provided, this will be passed on to the writer
            REL();

            if constexpr (do_gamma) {
                if constexpr (do_beta == 0) {
                    pack_reconfig_data_format(cb_out);
                }
                reconfig_data_format_srcb(cb_ex2pe, cb_gamma);
                ACQ();
                // Gamma output stays in cb_fusion when the beta stage still has to consume it.
                uint32_t cb_outg = do_beta ? cb_fusion : cb_out;
                mul_bcast_rows_init_short(cb_fusion, cb_gamma);
                cb_reserve_back(cb_outg, blk);
                cb_wait_front(cb_gamma, wt + blk);  // we don't pop, TODO: only wait on first ht
                cb_wait_front(cb_fusion, blk);
                for (uint32_t wtr = 0; wtr < blk; wtr++) {
                    mul_tiles_bcast_rows(cb_fusion, cb_gamma, wtr, wt + wtr, wtr);  // tile *= gamma (row broadcast)
                    pack_tile(wtr, cb_outg);  // cb_fusion if the beta stage follows, otherwise cb_out
                }
                cb_pop_front(cb_fusion, blk);
                // we don't pop gamma
                cb_push_back(cb_outg, blk);
                // We don't pop gamma since it's 1,1,1,Wt and we reuse it for all NCHt
                REL();
            }
            if constexpr (do_beta) {
                pack_reconfig_data_format(cb_out);
                // srcb was last configured for cb_gamma if the gamma stage ran, else for cb_ex2pe.
                if constexpr (do_gamma) {
                    reconfig_data_format_srcb(cb_gamma, cb_beta);
                } else {
                    reconfig_data_format_srcb(cb_ex2pe, cb_beta);
                }
                ACQ();
                add_bcast_rows_init_short(cb_fusion, cb_beta);
                cb_reserve_back(cb_out, blk);
                cb_wait_front(cb_beta, wt + blk);  // TODO: optimization - only wait on first ht
                cb_wait_front(cb_fusion, blk);
                for (uint32_t wtr = 0; wtr < blk; wtr++) {
                    add_tiles_bcast_rows(cb_fusion, cb_beta, wtr, wt + wtr, wtr);  // tile += beta (row broadcast)
                    pack_tile(wtr, cb_out);  // beta is the last stage, so this is the final output
                }
                cb_pop_front(cb_fusion, blk);
                // We don't pop beta since it's 1,1,1,Wt and we reuse it for all NCHt
                cb_push_back(cb_out, blk);
                REL();
            }
        }
        // Per-iteration intermediates are released only after every block has been emitted.
        cb_pop_front(cb_ex2pe, onetile);
        cb_pop_front(cb_xmm, Wt);

    }  // NCHt loop
}
}  // namespace NAMESPACE
